import jax.numpy as np
import importlib
import jax.experimental.optimizers as jaxoptimizers
import scalevi.utils.utils as utils
import scalevi.utils.utils_optimization as utils_optim
import scalevi.optimizers.optimizers as optimizers

def get_optim_arguments(config_dict):
    """Build the keyword arguments used to construct an optimizer.

    Args:
        config_dict: configuration mapping. ``'optimizer'`` and ``'max_step'``
            are required. The optional keys read are
            ``'optimizer_step_drop'`` (default False),
            ``'optimizer_step_drop_rate'`` (default 0.1),
            ``'optimizer_step_drop_count'`` (default 3), ``'n_iter'``,
            ``'optimizer_mass'`` (default 0.9) and ``'optimizer_eps'``
            (default 1e-8).

    Returns:
        dict that always contains ``'step_size'``. It also contains
        ``'mass'`` when the optimizer is a momentum-style one, or ``'eps'``
        when it is an adaptive-step one.

    Raises:
        KeyError: if ``'max_step'`` or ``'optimizer'`` is missing.
    """
    momentum_names = ('momentum', 'nesterov', 'custom_momentum_star')
    adaptive_names = ('adam', 'adamax', 'rmsprop', 'custom_adabelief',
                      'rmsprop_momentum', 'custom_amsgrad')

    use_drop = config_dict.get('optimizer_step_drop', False)
    if not use_drop:
        # Constant step size.
        step_size = config_dict['max_step']
    else:
        # Step-size schedule built by the project helper.
        # NOTE(review): 'n_iter' is read with .get(), so None is passed when
        # it is absent -- confirm dropping_stepsize tolerates that.
        step_size = utils_optim.dropping_stepsize(
            config_dict['max_step'],
            config_dict.get('optimizer_step_drop_rate', 0.1),
            config_dict.get('optimizer_step_drop_count', 3),
            config_dict.get('n_iter'),
        )
    kwargs = {'step_size': step_size}

    name = config_dict['optimizer']
    if name in momentum_names:
        kwargs['mass'] = config_dict.get('optimizer_mass', 0.9)
    if name in adaptive_names:
        kwargs['eps'] = config_dict.get('optimizer_eps', 1e-8)
    return kwargs

def get_optimizer(config_dict):
    """Resolve the optimizer named in ``config_dict`` and instantiate it.

    The name in ``config_dict['optimizer']`` is looked up with
    ``utils.get_attribute`` against the project ``optimizers`` module and
    ``jax.experimental.optimizers`` (presumably searched in that order --
    confirm in ``utils.get_attribute``). The resolved factory is called with
    the keyword arguments produced by ``get_optim_arguments(config_dict)``,
    and whatever it returns is passed back unchanged.
    """
    search_modules = [optimizers, jaxoptimizers]
    optimizer_factory = utils.get_attribute(search_modules,
                                            config_dict['optimizer'])
    factory_kwargs = get_optim_arguments(config_dict)
    return optimizer_factory(**factory_kwargs)
